"""Evaluate credit assignment effectiveness.
We compare component contribution estimates vs actual ablation drops.
"""
import json
from marl.train import load_dataset
from marl.environments.pipeline_env import PipelineEnvironment
from marl.utils.credit_assignment import CreditAssignment
from marl.environments.ml_components import COMPONENT_MAP
from experiments.utils import load_dataset_safely, seed_everything


def evaluate_credit(dataset: str, pipeline, repeats: int = 3):
    """Compare estimated component credits with leave-one-out ablation drops.

    The pipeline is scored once to obtain a baseline, ``CreditAssignment``
    is asked for its per-component estimates, and then each component is
    removed in turn to measure the actual drop in performance. Positive
    drops are normalised to sum to 1 and compared with the estimates.

    Args:
        dataset: Dataset name handed to ``load_dataset_safely``.
        pipeline: Sequence of component names, normally terminated by
            ``'END_PIPELINE'``. The terminator itself is never ablated.
        repeats: Accepted for interface compatibility; the current
            implementation does not use it.

    Returns:
        A dict with ``'base_performance'`` (baseline score),
        ``'component_report'`` (one dict per component with ``'component'``,
        ``'true_credit'``, ``'estimated_credit'`` and ``'abs_error'``) and
        ``'mae'`` (mean of the absolute errors, 0 for an empty report).

    Raises:
        RuntimeError: If the dataset could not be loaded; the message is the
            one returned by ``load_dataset_safely``.
    """
    seed_everything(42)
    data, msg = load_dataset_safely(dataset)
    if data is None:
        raise RuntimeError(msg)
    env = PipelineEnvironment(
        data,
        available_components=list(COMPONENT_MAP.keys()),
        max_pipeline_length=8,
        debug=False,
    )
    credit = CreditAssignment()

    def evaluate(candidate):
        # Single scoring path shared by the estimator and the ablation loop.
        return env.evaluate_with_timeout(candidate, timeout=env.eval_timeout)

    # Evaluate base performance (the status flag is returned but not needed here).
    base_perf, _status = env.evaluate_with_timeout(
        pipeline, timeout=env.eval_timeout, return_status=True
    )

    # Compute component credits with adaptive method (simulate training output)
    est_credits = credit.assign_component_credit(pipeline, base_perf, evaluate)

    # Only skip the final element when it really is the terminator; slicing
    # off the last entry unconditionally would silently ignore the last real
    # component of a pipeline passed without END_PIPELINE.
    has_terminator = len(pipeline) > 0 and pipeline[-1] == 'END_PIPELINE'
    ablatable = pipeline[:-1] if has_terminator else pipeline

    # Ground-truth via full ablation.
    # NOTE: results are keyed by component name, so a component that occurs
    # more than once keeps only the drop measured for its last occurrence.
    true_contrib = {}
    for i, comp in enumerate(ablatable):
        mod = pipeline[:i] + pipeline[i + 1:]
        mod_perf = evaluate(mod)
        # Removing a component that hurts performance counts as zero credit.
        true_contrib[comp] = max(0, base_perf - mod_perf)

    # Normalize ground truth (left untouched when no component helped).
    total = sum(true_contrib.values())
    if total > 0:
        true_contrib = {name: drop / total for name, drop in true_contrib.items()}

    # Align components: components missing from the estimates count as 0.0.
    report = []
    for comp, true_credit in true_contrib.items():
        estimated = est_credits.get(comp, 0.0)
        report.append({
            'component': comp,
            'true_credit': true_credit,
            'estimated_credit': estimated,
            'abs_error': abs(true_credit - estimated),
        })

    mae = sum(r['abs_error'] for r in report) / max(1, len(report))
    return {'base_performance': base_perf, 'component_report': report, 'mae': mae}

if __name__ == '__main__':
    # Command-line entry point: evaluate one pipeline on one dataset and
    # print the resulting report as indented JSON.
    import argparse
    import ast

    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', default='iris')
    parser.add_argument('--pipeline', required=True, help='Python list of component string names including END_PIPELINE')
    args = parser.parse_args()

    # literal_eval only accepts Python literals, so the argument is parsed
    # without executing code. Malformed input is reported as a usage error
    # instead of a raw traceback.
    try:
        pipeline = ast.literal_eval(args.pipeline)
    except (ValueError, TypeError, SyntaxError) as exc:
        parser.error(f'--pipeline is not a valid Python literal: {exc}')
    # evaluate_credit slices and concatenates the pipeline, so anything that
    # is not a list/tuple would fail later with a less helpful error.
    if not isinstance(pipeline, (list, tuple)):
        parser.error('--pipeline must be a Python list of component names')

    res = evaluate_credit(args.dataset, pipeline)
    print(json.dumps(res, indent=2))
